{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import some libraries\n",
    "\n",
    "import torch\n",
    "import torchvision\n",
    "from torch import nn\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import transforms\n",
    "from torchvision.datasets import MNIST\n",
    "from matplotlib import pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert vector to image\n",
    "\n",
    "def to_img(x):\n",
    "    x = 0.5 * (x + 1)\n",
    "    x = x.view(x.size(0), 28, 28)\n",
    "    return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Displaying routine\n",
    "\n",
    "def display_images(in_, out, n=1):\n",
    "    for N in range(n):\n",
    "        if in_ is not None:\n",
    "            in_pic = to_img(in_.cpu().data)\n",
    "            plt.figure(figsize=(18, 6))\n",
    "            for i in range(4):\n",
    "                plt.subplot(1,4,i+1)\n",
    "                plt.imshow(in_pic[i+4*N])\n",
    "                plt.axis('off')\n",
    "        out_pic = to_img(out.cpu().data)\n",
    "        plt.figure(figsize=(18, 6))\n",
    "        for i in range(4):\n",
    "            plt.subplot(1,4,i+1)\n",
    "            plt.imshow(out_pic[i+4*N])\n",
    "            plt.axis('off')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define data loading step\n",
    "\n",
    "batch_size = 256\n",
    "\n",
    "img_transform = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.5,), (0.5,))\n",
    "])\n",
    "\n",
    "dataset = MNIST('./data', transform=img_transform, download=True)\n",
    "dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define model architecture and reconstruction loss\n",
    "\n",
    "# n = 28 x 28 = 784\n",
    "d = 30  # for standard AE (under-complete hidden layer)\n",
    "# d = 500  # for denoising AE (over-complete hidden layer)\n",
    "\n",
    "class Autoencoder(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.encoder = nn.Sequential(\n",
    "            nn.Linear(28 * 28, d),\n",
    "            nn.Tanh(),\n",
    "        )\n",
    "        self.decoder = nn.Sequential(\n",
    "            nn.Linear(d, 28 * 28),\n",
    "            nn.Tanh(),\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.encoder(x)\n",
    "        x = self.decoder(x)\n",
    "        return x\n",
    "    \n",
    "model = Autoencoder().to(device)\n",
    "criterion = nn.MSELoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configure the optimiser\n",
    "\n",
    "learning_rate = 1e-3\n",
    "\n",
    "optimizer = torch.optim.Adam(\n",
    "    model.parameters(),\n",
    "    lr=learning_rate,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "*Comment* or *un-comment out* a few lines of code to seamlessly switch between *standard AE* and *denoising one*.\n",
    "\n",
    "Don't forget to **(1)** change the size of the hidden layer accordingly, **(2)** re-generate the model, and **(3)** re-pass the parameters to the optimiser."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train standard or denoising autoencoder (AE)\n",
    "\n",
    "num_epochs = 1\n",
    "# do = nn.Dropout()  # comment out for standard AE\n",
    "for epoch in range(num_epochs):\n",
    "    for data in dataloader:\n",
    "        img, _ = data\n",
    "        img = img.to(device)\n",
    "        img.requires_grad_()\n",
    "        img = img.view(img.size(0), -1)\n",
    "#         noise = do(torch.ones(img.shape))\n",
    "#         img_bad = (img * noise).to(device)  # comment out for standard AE\n",
    "        # ===================forward=====================\n",
    "        output = model(img)  # feed <img> (for std AE) or <img_bad> (for denoising AE)\n",
    "        loss = criterion(output, img.data)\n",
    "        # ===================backward====================\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "    # ===================log========================\n",
    "    print(f'epoch [{epoch + 1}/{num_epochs}], loss:{loss.item():.4f}')\n",
    "    display_images(None, output)  # pass (None, output) for std AE, (img_bad, output) for denoising AE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualise a few kernels of the encoder\n",
    "\n",
    "display_images(None, model.encoder[0].weight, 5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Let's compare the autoencoder inpainting capabilities vs. OpenCV\n",
    "\n",
    "from cv2 import inpaint, INPAINT_NS, INPAINT_TELEA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inpaint with Telea and Navier-Stokes methods\n",
    "\n",
    "dst_TELEA = list()\n",
    "dst_NS = list()\n",
    "\n",
    "for i in range(3, 7):\n",
    "    corrupted_img = ((img_bad.data.cpu()[i].view(28, 28) / 4 + 0.5) * 255).byte().numpy()\n",
    "    mask = 2 - noise.cpu()[i].view(28, 28).byte().numpy()\n",
    "    dst_TELEA.append(inpaint(corrupted_img, mask, 3, INPAINT_TELEA))\n",
    "    dst_NS.append(inpaint(corrupted_img, mask, 3, INPAINT_NS))\n",
    "\n",
    "tns_TELEA = [torch.from_numpy(d) for d in dst_TELEA]\n",
    "tns_NS = [torch.from_numpy(d) for d in dst_NS]\n",
    "\n",
    "TELEA = torch.stack(tns_TELEA).float()\n",
    "NS = torch.stack(tns_NS).float()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare the results: [noise], [img + noise], [img], [AE, Telea, Navier-Stokes] inpainting\n",
    "\n",
    "with torch.no_grad():\n",
    "    display_images(noise[3:7], img_bad[3:7])\n",
    "    display_images(img[3:7], output[3:7])\n",
    "    display_images(TELEA, NS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:dl-minicourse] *",
   "language": "python",
   "name": "conda-env-dl-minicourse-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
